{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "86d962f8-d1c9-4e23-9060-485d70999883",
   "metadata": {},
   "source": [
    "# 分类\n",
    "\n",
    "> 使用带有结构化输出的聊天模型将文本分类到不同的类别或标签中。\n",
    "\n",
    "## 将文本分类到标签中\n",
    "\n",
    "“打标签”（Tagging）是指使用以下类型的类别对文档进行标注：\n",
    "\n",
    "- **情感**（正面、负面、中性等）  \n",
    "- **语言**（中文、英文、法语等）  \n",
    "- **风格**（正式、非正式等）  \n",
    "- **涉及的主题**  \n",
    "- **政治倾向**\n",
    "\n",
    "\n",
    "![示例图片](../assets/imgs/tagging.png)  \n",
    "\n",
    "### 概述  \n",
    "**打标签（Tagging）** 包含几个组成部分：\n",
    "\n",
    "- **函数（function）**：和信息提取类似，打标签也使用函数来指定模型应该如何对文档进行标注  \n",
    "- **模式（schema）** ：定义了我们希望如何对文档进行打标签  \n",
    "\n",
    "---\n",
    "\n",
    "### 快速开始  \n",
    "下面我们来看一个非常简单的例子，展示如何在 LangChain 中使用 OpenAI 的工具调用功能来进行**文本打标签**。\n",
    "\n",
    "我们会使用 OpenAI 模型支持的 `with_structured_output` 方法。\n",
    "\n",
    "（注：这个方法可以让模型输出结构化的数据格式，比如 JSON，非常适合用来做分类、打标签等任务。）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "14eef82b-937f-48dd-ab6b-6237be0dbfaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "try:\n",
    "    # load environment variables from .env file (requires `python-dotenv`)\n",
    "    from dotenv import load_dotenv\n",
    "\n",
    "    _ = load_dotenv()\n",
    "except ImportError:\n",
    "    pass\n",
    "\n",
    "if not os.environ.get(\"DASHSCOPE_API_KEY\"):\n",
    "  os.environ[\"DASHSCOPE_API_KEY\"] = getpass.getpass(\"Enter API key for OpenAI: \")\n",
    "\n",
    "from langchain_community.chat_models.tongyi import ChatTongyi\n",
    "\n",
    "llm = ChatTongyi(\n",
    "    streaming=True,\n",
    "    name=\"qwen-turbo\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c0d6efc-4b5c-4427-9398-6189993a438e",
   "metadata": {},
   "source": [
    "我们可以在模式（schema）中指定一个 Pydantic 模型，包含一些属性及其预期的数据类型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "bd1ed4ad-db39-43d2-abc6-7b486690d9c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_openai import ChatOpenAI\n",
    "from pydantic import BaseModel, Field\n",
    "\n",
    "\n",
    "# 从以下段落中提取所需信息。\n",
    "# 仅提取“分类”函数中提到的属性。\n",
    "# 段落：\n",
    "\n",
    "tagging_prompt = ChatPromptTemplate.from_template(\n",
    "    \"\"\"\n",
    "Extract the desired information from the following passage.\n",
    "\n",
    "Only extract the properties mentioned in the 'Classification' function.\n",
    "\n",
    "Passage:\n",
    "{input}\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "class Classification(BaseModel):\n",
    "    # 文本的情感\n",
    "    sentiment: str = Field(description=\"The sentiment of the text\") \n",
    "    # 文字的攻击性程度（1 到 10 分\n",
    "    aggressiveness: int = Field(\n",
    "        description=\"How aggressive the text is on a scale from 1 to 10\"\n",
    "    )\n",
    "    # 文本所用的语言\n",
    "    language: str = Field(description=\"The language the text is written in\")\n",
    "\n",
    "# Structured LLM\n",
    "structured_llm = llm.with_structured_output(Classification)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "e1f41a19-31c6-45f6-ad21-2faf26bb96a2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Classification(sentiment='negative', aggressiveness=3, language='zh')"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_content = \"哇，又迟到了，真是太棒了。\"\n",
    "prompt = tagging_prompt.invoke({\"input\": input_content})\n",
    "response = structured_llm.invoke(prompt)\n",
    "\n",
    "response"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "313c81af-18ee-4669-a65e-a3d19dfe2c7a",
   "metadata": {},
   "source": [
    "如果我们希望得到字典格式的输出，只需调用 `.model_dump()` 方法即可。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "d06a2eaa-f25a-4e21-85ee-037b0dc87aca",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'sentiment': 'positive', 'aggressiveness': 5, 'language': 'Chinese'}"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inp = \"这个点子太‘绝’了。\"\n",
    "prompt = tagging_prompt.invoke({\"input\": inp})\n",
    "response = structured_llm.invoke(prompt)\n",
    "\n",
    "response.model_dump()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "092a4167-77ed-4bc1-94e6-71ae72e430ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'sentiment': 'positive', 'aggressiveness': 5, 'language': 'Chinese'}\n"
     ]
    }
   ],
   "source": [
    "if hasattr(response, \"model_dump\"):\n",
    "    print(response.model_dump())\n",
    "else:\n",
    "    print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bbf5ee6-201c-409d-8f4d-9fa75de6842a",
   "metadata": {},
   "source": [
    "正如我们在示例中看到的那样，模型能正确理解我们的需求。\n",
    "\n",
    "输出的结果可能会有所不同，例如，情感标签可能是不同语言的（如英文的 `\"positive\"`、中文的 `\"积极\"` 等）。\n",
    "\n",
    "我们将在下面学习如何控制这些输出结果。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54221cce-b476-4b14-8553-655a3829486d",
   "metadata": {},
   "source": [
    "### 更精细的控制  \n",
    "通过仔细定义模式（schema），我们可以更好地控制模型的输出结果。\n",
    "\n",
    "具体来说，我们可以在模型中定义以下内容：\n",
    "\n",
    "- 每个属性的**可选值**\n",
    "- 属性的**描述信息**，确保模型准确理解每个字段的含义\n",
    "- 必须返回的**必填属性**\n",
    "\n",
    "下面我们重新定义我们的 Pydantic 模型，使用枚举（enum）来实现上述各个方面的控制："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "8bde1c3f-ea14-4e80-9560-81d1064fb2e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Classification(BaseModel):\n",
    "    sentiment: str = Field(..., enum=[\"开心\", \"中性\", \"悲伤\"])\n",
    "    aggressiveness: int = Field(\n",
    "        ...,\n",
    "        description=\"描述言论的攻击性，数字越高，攻击性越强\",\n",
    "        enum=[1, 2, 3, 4, 5],\n",
    "    )\n",
    "    language: str = Field(\n",
    "        ..., enum=[\"中文\", \"英文\", \"日本\", \"法国\", \"繁体中文\"]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "304d5c0c-fbc6-4dc1-8328-6f6da623a039",
   "metadata": {},
   "outputs": [],
   "source": [
    "tagging_prompt = ChatPromptTemplate.from_template(\n",
    "    \"\"\"\n",
    "Extract the desired information from the following passage.\n",
    "\n",
    "Only extract the properties mentioned in the 'Classification' function.\n",
    "\n",
    "Passage:\n",
    "{input}\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "chain = llm.with_structured_output(\n",
    "    Classification\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e04965f-f894-44cf-9b93-b6f478630dfd",
   "metadata": {},
   "source": [
    "现在，回答将会以我们期望的方式被限制（规范）！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "2fbec65c-69e9-47e3-b60d-d1a04431a15c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'sentiment': '开心', 'aggressiveness': 3, 'language': '中文'}\n"
     ]
    }
   ],
   "source": [
    "inp = \"真是太棒了，连点外卖都能送错，客服还不接电话，简直五星好评！\"\n",
    "prompt = tagging_prompt.invoke({\"input\": inp})\n",
    "response = chain.invoke(prompt)\n",
    "\n",
    "if hasattr(response, \"model_dump\"):\n",
    "    print(response.model_dump())\n",
    "else:\n",
    "    print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1949c2d-0339-40c2-8335-365017633f2a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
